package proxy

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"math/rand"
	"mime/multipart"
	"net/http"
	"net/http/httptest"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/mostlygeek/llama-swap/event"
	"github.com/mostlygeek/llama-swap/proxy/config"
	"github.com/stretchr/testify/assert"
	"github.com/tidwall/gjson"
)

// TestResponseRecorder is an httptest.ResponseRecorder that also implements
// CloseNotify, so it can be handed to handlers that need an
// http.CloseNotifier. Tests that exercise streaming responses need such a
// recorder; with a bare *httptest.ResponseRecorder they can panic with:
//
//	panic: interface conversion: *httptest.ResponseRecorder is not http.CloseNotifier: missing method CloseNotify
//
// See https://github.com/gin-gonic/gin/issues/1815. Adapted from gin's own tests:
// https://github.com/gin-gonic/gin/blob/ce20f107f5dc498ec7489d7739541a25dcd48463/context_test.go#L1747-L1765
type TestResponseRecorder struct {
	*httptest.ResponseRecorder
	closeChannel chan bool
}

// CloseNotify returns the channel that closeClient signals on.
func (rec *TestResponseRecorder) CloseNotify() <-chan bool {
	return rec.closeChannel
}

// closeClient simulates the client disconnecting by sending true on the close
// channel. The channel is created with a buffer of one, so a single call
// does not block.
func (rec *TestResponseRecorder) closeClient() {
	rec.closeChannel <- true
}

// CreateTestResponseRecorder returns a TestResponseRecorder wrapping a fresh
// httptest.ResponseRecorder, with a close channel buffered for one signal.
func CreateTestResponseRecorder() *TestResponseRecorder {
	rec := &TestResponseRecorder{
		ResponseRecorder: httptest.NewRecorder(),
		closeChannel:     make(chan bool, 1),
	}
	return rec
}

// TestProxyManager_SwapProcessCorrectly requests two models from the default
// group one after the other and checks that each response comes back 200 and
// mentions the model that was requested.
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
	cfg := config.AddDefaultGroupToConfig(config.Config{
		HealthCheckTimeout: 15,
		LogLevel:           "error",
		Models: map[string]config.ModelConfig{
			"model1": getTestSimpleResponderConfig("model1"),
			"model2": getTestSimpleResponderConfig("model2"),
		},
	})

	pm := New(cfg)
	defer pm.StopProcesses(StopWaitForInflightRequest)

	names := []string{"model1", "model2"}
	for i := range names {
		name := names[i]
		payload := bytes.NewBufferString(fmt.Sprintf(`{"model":"%s"}`, name))
		rec := CreateTestResponseRecorder()

		pm.ServeHTTP(rec, httptest.NewRequest("POST", "/v1/chat/completions", payload))

		assert.Equal(t, http.StatusOK, rec.Code)
		assert.Contains(t, rec.Body.String(), name)
	}
}
// TestProxyManager_SwapMultiProcess puts model1 and model2 in separate swap
// groups (G1 and G2), requests each once, and then checks that both processes
// are ready at the same time.
func TestProxyManager_SwapMultiProcess(t *testing.T) {
	config := config.AddDefaultGroupToConfig(config.Config{
		HealthCheckTimeout: 15,
		Models: map[string]config.ModelConfig{
			"model1": getTestSimpleResponderConfig("model1"),
			"model2": getTestSimpleResponderConfig("model2"),
		},
		LogLevel: "error",
		Groups: map[string]config.GroupConfig{
			"G1": {
				Swap:      true,
				Exclusive: false,
				Members:   []string{"model1"},
			},
			"G2": {
				Swap:      true,
				Exclusive: false,
				Members:   []string{"model2"},
			},
		},
	})

	proxy := New(config)
	defer proxy.StopProcesses(StopWaitForInflightRequest)

	tests := []string{"model1", "model2"}
	for _, requestedModel := range tests {
		t.Run(requestedModel, func(t *testing.T) {
			reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
			req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
			w := CreateTestResponseRecorder()

			proxy.ServeHTTP(w, req)
			assert.Equal(t, http.StatusOK, w.Code)
			assert.Contains(t, w.Body.String(), requestedModel)
		})
	}

	// Both models must be loaded at the same time. testify expects
	// (expected, actual); keeping that order makes failure output accurate.
	assert.Equal(t, StateReady, proxy.findGroupByModelName("model1").processes["model1"].CurrentState())
	assert.Equal(t, StateReady, proxy.findGroupByModelName("model2").processes["model2"].CurrentState())
}

// TestProxyManager_PersistentGroupsAreNotSwapped verifies that a persistent
// group is not affected by the swapping behaviour of other groups: after
// model2 (persistent "forever" group) and then model1 (default group) are
// loaded, both must still be ready.
func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
	config := config.AddDefaultGroupToConfig(config.Config{
		HealthCheckTimeout: 15,
		Models: map[string]config.ModelConfig{
			"model1": getTestSimpleResponderConfig("model1"), // goes into the default group
			"model2": getTestSimpleResponderConfig("model2"),
		},
		LogLevel: "error",
		Groups: map[string]config.GroupConfig{
			// the forever group is persistent and should not be affected by model1
			"forever": {
				Swap:       true,
				Exclusive:  false,
				Persistent: true,
				Members:    []string{"model2"},
			},
		},
	})

	proxy := New(config)
	defer proxy.StopProcesses(StopWaitForInflightRequest)

	// make requests to load all models, loading model1 should not affect model2
	tests := []string{"model2", "model1"}
	for _, requestedModel := range tests {
		reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
		req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
		w := CreateTestResponseRecorder()

		proxy.ServeHTTP(w, req)
		assert.Equal(t, http.StatusOK, w.Code)
		assert.Contains(t, w.Body.String(), requestedModel)
	}

	// testify expects (expected, actual)
	assert.Equal(t, StateReady, proxy.findGroupByModelName("model2").processes["model2"].CurrentState())
	assert.Equal(t, StateReady, proxy.findGroupByModelName("model1").processes["model1"].CurrentState())
}

// TestProxyManager_SwapMultiProcessParallelRequests sends requests for three
// models in the default group concurrently. When a request for a different
// model comes in, ProxyManager should wait until the in-flight request is
// complete before swapping, so every request must succeed and receive its own
// model's response.
func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping slow test")
	}

	config := config.AddDefaultGroupToConfig(config.Config{
		HealthCheckTimeout: 15,
		Models: map[string]config.ModelConfig{
			"model1": getTestSimpleResponderConfig("model1"),
			"model2": getTestSimpleResponderConfig("model2"),
			"model3": getTestSimpleResponderConfig("model3"),
		},
		LogLevel: "error",
	})

	proxy := New(config)
	defer proxy.StopProcesses(StopWaitForInflightRequest)

	// results maps requested model -> responseMessage; guarded by mu
	results := map[string]string{}

	var wg sync.WaitGroup
	var mu sync.Mutex

	for key := range config.Models {
		wg.Add(1)
		go func(key string) {
			defer wg.Done()

			reqBody := fmt.Sprintf(`{"model":"%s"}`, key)
			// NOTE(review): wait=1000ms presumably makes simple-responder delay
			// its reply so the requests overlap — confirm
			req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody))
			w := CreateTestResponseRecorder()

			proxy.ServeHTTP(w, req)
			assert.Equal(t, http.StatusOK, w.Code, "unexpected status for key %s", key)

			// Decode outside the lock; only the shared map needs protection.
			var response map[string]interface{}
			assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response), "invalid JSON for key %s", key)
			result, ok := response["responseMessage"].(string)
			assert.True(t, ok, "responseMessage should be a string for key %s", key)

			mu.Lock()
			results[key] = result
			mu.Unlock()
		}(key)

		// stagger request start times by about a millisecond
		<-time.After(time.Millisecond)
	}

	wg.Wait()
	assert.Len(t, results, len(config.Models))

	for key, result := range results {
		assert.Equal(t, key, result)
	}
}

// TestProxyManager_ListModelsHandler checks GET /v1/models. It verifies that:
//   - the request Origin is echoed in Access-Control-Allow-Origin;
//   - every model record has the expected fields;
//   - name/description appear only for model1, where they are non-blank.
func TestProxyManager_ListModelsHandler(t *testing.T) {
	m1 := getTestSimpleResponderConfig("model1")
	m1.Name = "Model 1"
	m1.Description = "Model 1 description is used for testing"

	m2 := getTestSimpleResponderConfig("model2")
	m2.Name = "     " // empty whitespace only strings will get ignored
	m2.Description = "  "

	cfg := config.Config{
		HealthCheckTimeout: 15,
		LogLevel:           "error",
		Models: map[string]config.ModelConfig{
			"model1": m1,
			"model2": m2,
			"model3": getTestSimpleResponderConfig("model3"),
		},
	}
	pm := New(cfg)

	req := httptest.NewRequest("GET", "/v1/models", nil)
	req.Header.Add("Origin", "i-am-the-origin")
	rec := CreateTestResponseRecorder()
	pm.ServeHTTP(rec, req)

	assert.Equal(t, http.StatusOK, rec.Code)
	// the request's Origin should be reflected back for CORS
	assert.Equal(t, req.Header.Get("Origin"), rec.Result().Header.Get("Access-Control-Allow-Origin"))

	var payload struct {
		Data []map[string]interface{} `json:"data"`
	}
	err := json.Unmarshal(rec.Body.Bytes(), &payload)
	if err != nil {
		t.Fatalf("Failed to parse JSON response: %v", err)
	}

	assert.Len(t, payload.Data, 3)

	// remaining holds the expected IDs not yet seen in the response
	remaining := map[string]struct{}{"model1": {}, "model2": {}, "model3": {}}

	for _, entry := range payload.Data {
		id, isString := entry["id"].(string)
		assert.True(t, isString, "model ID should be a string")
		_, expected := remaining[id]
		assert.True(t, expected, "unexpected model ID: %s", id)
		delete(remaining, id)

		obj, isString := entry["object"].(string)
		assert.True(t, isString, "object should be a string")
		assert.Equal(t, "model", obj)

		createdAt, isNumber := entry["created"].(float64)
		assert.True(t, isNumber, "created should be a number")
		assert.Greater(t, createdAt, float64(0)) // Assuming the timestamp is positive

		owner, isString := entry["owned_by"].(string)
		assert.True(t, isString, "owned_by should be a string")
		assert.Equal(t, "llama-swap", owner)

		if id != "model1" {
			// blank or missing name/description must not be emitted at all
			_, has := entry["name"]
			assert.False(t, has, "unexpected name field for model: %s", id)
			_, has = entry["description"]
			assert.False(t, has, "unexpected description field for model: %s", id)
			continue
		}

		name, isString := entry["name"].(string)
		assert.True(t, isString, "name should be a string")
		assert.Equal(t, "Model 1", name)
		desc, isString := entry["description"].(string)
		assert.True(t, isString, "description should be a string")
		assert.Equal(t, "Model 1 description is used for testing", desc)
	}

	// every expected ID should have been removed from the set above
	assert.Empty(t, remaining, "not all expected models were returned")
}

// TestProxyManager_ListModelsHandler_WithMetadata checks how model metadata
// is exposed by /v1/models. It verifies that:
//   - model1's metadata block appears under meta.llamaswap after macro
//     substitution;
//   - YAML scalar types and nesting are preserved in that block;
//   - model2, which has no metadata, has no meta key.
func TestProxyManager_ListModelsHandler_WithMetadata(t *testing.T) {
	// Process config through LoadConfigFromReader to apply macro substitution
	configYaml := `
healthCheckTimeout: 15
logLevel: error
startPort: 10000
models:
  model1:
    cmd: /path/to/server -p ${PORT}
    macros:
      PORT_NUM: 10001
      TEMP: 0.7
      NAME: "llama"
    metadata:
      port: ${PORT_NUM}
      temperature: ${TEMP}
      enabled: true
      note: "Running on port ${PORT_NUM}"
      nested:
        value: ${TEMP}
  model2:
    cmd: /path/to/server -p ${PORT}
`
	processedConfig, err := config.LoadConfigFromReader(strings.NewReader(configYaml))
	if !assert.NoError(t, err) {
		return
	}

	proxy := New(processedConfig)

	req := httptest.NewRequest("GET", "/v1/models", nil)
	w := CreateTestResponseRecorder()
	proxy.ServeHTTP(w, req)

	assert.Equal(t, http.StatusOK, w.Code)

	var response struct {
		Data []map[string]any `json:"data"`
	}

	err = json.Unmarshal(w.Body.Bytes(), &response)
	if !assert.NoError(t, err) {
		return
	}
	assert.Len(t, response.Data, 2)

	// Find model1 and model2 in response
	var model1Data, model2Data map[string]any
	for _, model := range response.Data {
		switch model["id"] {
		case "model1":
			model1Data = model
		case "model2":
			model2Data = model
		}
	}

	// model1 must carry its metadata under meta.llamaswap. Checked type
	// assertions make a malformed response fail the test instead of panicking.
	if !assert.NotNil(t, model1Data, "model1 should be listed") {
		return
	}
	metaMap, ok := model1Data["meta"].(map[string]any)
	if !assert.True(t, ok, "model1 should have a meta object") {
		return
	}
	lsmetamap, ok := metaMap["llamaswap"].(map[string]any)
	if !assert.True(t, ok, "model1 should have a meta.llamaswap object") {
		return
	}

	// Verify type preservation
	assert.Equal(t, float64(10001), lsmetamap["port"]) // JSON numbers are float64
	assert.Equal(t, 0.7, lsmetamap["temperature"])
	assert.Equal(t, true, lsmetamap["enabled"])
	// Verify string interpolation
	assert.Equal(t, "Running on port 10001", lsmetamap["note"])
	// Verify nested structure
	nested, ok := lsmetamap["nested"].(map[string]any)
	if assert.True(t, ok, "meta.llamaswap.nested should be an object") {
		assert.Equal(t, 0.7, nested["value"])
	}

	// model2 has no metadata configured. Metadata is emitted under the "meta"
	// key (see model1 above), so that is the key that must be absent here.
	if assert.NotNil(t, model2Data, "model2 should be listed") {
		_, exists := model2Data["meta"]
		assert.False(t, exists, "model2 should not have meta")
	}
}

// TestProxyManager_ListModelsHandler_SortedByID verifies that /v1/models
// returns listed models ordered by ascending id and leaves out models marked
// Unlisted.
func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) {
	// Models are configured out of alphabetical order, plus one unlisted model.
	models := map[string]config.ModelConfig{}
	for _, id := range []string{"zeta", "alpha", "beta", "hidden"} {
		models[id] = getTestSimpleResponderConfig(id)
	}
	hiddenCfg := models["hidden"]
	hiddenCfg.Unlisted = true
	models["hidden"] = hiddenCfg

	pm := New(config.Config{
		HealthCheckTimeout: 15,
		Models:             models,
		LogLevel:           "error",
	})

	rec := CreateTestResponseRecorder()
	pm.ServeHTTP(rec, httptest.NewRequest("GET", "/v1/models", nil))

	assert.Equal(t, http.StatusOK, rec.Code)

	var listing struct {
		Data []map[string]interface{} `json:"data"`
	}
	if err := json.Unmarshal(rec.Body.Bytes(), &listing); err != nil {
		t.Fatalf("Failed to parse JSON response: %v", err)
	}

	// only the listed models, sorted by id
	want := []string{"alpha", "beta", "zeta"}
	if !assert.Len(t, listing.Data, len(want), "unexpected number of listed models") {
		return
	}
	ids := make([]string, len(listing.Data))
	for i, entry := range listing.Data {
		ids[i], _ = entry["id"].(string)
	}
	assert.Equal(t, want, ids, "models should be sorted by id ascending")
}

func TestProxyManager_ListModelsHandler_IncludeAliasesInList(t *testing.T) {
	// Configure alias
	config := config.Config{
		HealthCheckTimeout:   15,
		IncludeAliasesInList: true,
		Models: map[string]config.ModelConfig{
			"model1": func() config.ModelConfig {
				mc := getTestSimpleResponderConfig("model1")
				mc.Name = "Model 1"
				mc.Aliases = []string{"alias1"}
				return mc
			}(),
		},
		LogLevel: "error",
	}

	proxy := New(config)

	// Request models list
	req := httptest.NewRequest("GET", "/v1/models", nil)
	w := CreateTestResponseRecorder()
	proxy.ServeHTTP(w, req)

	assert.Equal(t, http.StatusOK, w.Code)

	var response struct {
		Data []map[string]interface{} `json:"data"`
	}
	if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
		t.Fatalf("Failed to parse JSON response: %v", err)
	}

	// We expect both base id and alias
	var model1Data, alias1Data map[string]any
	for _, model := range response.Data {
		if model["id"] == "model1" {
			model1Data = model
		} else if model["id"] == "alias1" {
			alias1Data = model
		}
	}

	// Verify model1 has name
	assert.NotNil(t, model1Data)
	_, exists := model1Data["name"]
	if !assert.True(t, exists, "model1 should have name key") {
		t.FailNow()
	}
	name1, ok := model1Data["name"].(string)
	assert.True(t, ok, "name1 should be a string")

	// Verify alias1 has name
	assert.NotNil(t, alias1Data)
	_, exists = alias1Data["name"]
	if !assert.True(t, exists, "alias1 should have name key") {
		t.FailNow()
	}
	name2, ok := alias1Data["name"].(string)
	assert.True(t, ok, "name2 should be a string")

	// Name keys should match
	assert.Equal(t, name1, name2)
}

// TestProxyManager_Shutdown verifies that Shutdown interrupts requests that
// are blocked waiting for their model to start. Each such request is expected
// to fail with 502 Bad Gateway and a body containing
// "health check interrupted due to shutdown".
func TestProxyManager_Shutdown(t *testing.T) {
	// make broken model configurations: each model's Proxy URL points at port
	// 1000x while its responder config is built with port 999x, so the health
	// check presumably never succeeds and the requests below stay blocked.
	// NOTE(review): relies on getTestSimpleResponderConfigPort binding the
	// responder to the given port — confirm.
	model1Config := getTestSimpleResponderConfigPort("model1", 9991)
	model1Config.Proxy = "http://localhost:10001/"

	model2Config := getTestSimpleResponderConfigPort("model2", 9992)
	model2Config.Proxy = "http://localhost:10002/"

	model3Config := getTestSimpleResponderConfigPort("model3", 9993)
	model3Config.Proxy = "http://localhost:10003/"

	// All three models share one non-swapping group, presumably so that all
	// of them can be starting at the same time.
	config := config.AddDefaultGroupToConfig(config.Config{
		HealthCheckTimeout: 15,
		Models: map[string]config.ModelConfig{
			"model1": model1Config,
			"model2": model2Config,
			"model3": model3Config,
		},
		LogLevel: "error",
		Groups: map[string]config.GroupConfig{
			"test": {
				Swap:    false,
				Members: []string{"model1", "model2", "model3"},
			},
		},
	})

	proxy := New(config)

	// Start all the processes
	var wg sync.WaitGroup
	for _, modelName := range []string{"model1", "model2", "model3"} {
		wg.Add(1)
		go func(modelName string) {
			defer wg.Done()
			reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
			req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
			w := CreateTestResponseRecorder()

			// send a request to trigger the proxy to load ... this should hang waiting for start up
			proxy.ServeHTTP(w, req)
			// once Shutdown runs, the pending request must be failed, not served
			assert.Equal(t, http.StatusBadGateway, w.Code)
			assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown")
		}(modelName)
	}

	// Shut down after one second, well inside HealthCheckTimeout (15,
	// presumably seconds), so the requests are still waiting when it happens.
	go func() {
		<-time.After(time.Second)
		proxy.Shutdown()
	}()
	wg.Wait()
}

// TestProxyManager_Unload loads model1, calls GET /unload, and checks that
// the response is 200 "OK" and the model1 process reaches StateStopped.
func TestProxyManager_Unload(t *testing.T) {
	conf := config.AddDefaultGroupToConfig(config.Config{
		HealthCheckTimeout: 15,
		Models: map[string]config.ModelConfig{
			"model1": getTestSimpleResponderConfig("model1"),
		},
		LogLevel: "error",
	})

	proxy := New(conf)
	// safety net: stops the process if the test fails before /unload runs
	defer proxy.StopProcesses(StopImmediately)

	reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1")
	req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
	w := CreateTestResponseRecorder()
	proxy.ServeHTTP(w, req)
	assert.Equal(t, http.StatusOK, w.Code, "loading model1")

	// testify expects (expected, actual)
	assert.Equal(t, StateReady, proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].CurrentState())

	req = httptest.NewRequest("GET", "/unload", nil)
	w = CreateTestResponseRecorder()
	proxy.ServeHTTP(w, req)
	assert.Equal(t, http.StatusOK, w.Code)
	assert.Equal(t, "OK", w.Body.String())

	select {
	case <-proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].cmdWaitChan:
		// good
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for model1 to stop")
	}
	assert.Equal(t, StateStopped, proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].CurrentState())
}

// TestProxyManager_UnloadSingleModel verifies that
// POST /api/models/unload/<model> stops only the named model and leaves the
// other member of its group running.
func TestProxyManager_UnloadSingleModel(t *testing.T) {
	const testGroupID = "testGroup"
	config := config.AddDefaultGroupToConfig(config.Config{
		HealthCheckTimeout: 15,
		Models: map[string]config.ModelConfig{
			"model1": getTestSimpleResponderConfig("model1"),
			"model2": getTestSimpleResponderConfig("model2"),
		},
		Groups: map[string]config.GroupConfig{
			testGroupID: {
				Swap:    false,
				Members: []string{"model1", "model2"},
			},
		},
		LogLevel: "error",
	})

	proxy := New(config)
	defer proxy.StopProcesses(StopImmediately)

	// start both models
	for _, modelName := range []string{"model1", "model2"} {
		reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
		req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
		w := CreateTestResponseRecorder()
		proxy.ServeHTTP(w, req)
		assert.Equal(t, http.StatusOK, w.Code, "loading %s", modelName)
	}

	assert.Equal(t, StateReady, proxy.processGroups[testGroupID].processes["model1"].CurrentState())
	assert.Equal(t, StateReady, proxy.processGroups[testGroupID].processes["model2"].CurrentState())

	req := httptest.NewRequest("POST", "/api/models/unload/model1", nil)
	w := CreateTestResponseRecorder()
	proxy.ServeHTTP(w, req)
	assert.Equal(t, http.StatusOK, w.Code)
	if !assert.Equal(t, "OK", w.Body.String()) {
		t.FailNow()
	}

	select {
	case <-proxy.processGroups[testGroupID].processes["model1"].cmdWaitChan:
		// good
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for model1 to stop")
	}

	// only model1 is stopped; model2 keeps running (testify: expected, actual)
	assert.Equal(t, StateStopped, proxy.processGroups[testGroupID].processes["model1"].CurrentState())
	assert.Equal(t, StateReady, proxy.processGroups[testGroupID].processes["model2"].CurrentState())
}

// TestProxyManager_RunningEndpoint covers issue #61 (`Listing the current list
// of models and the loaded model.`). GET /running should report no entries
// before any model is requested, and after model1 is loaded it should report
// model1 with state "ready".
func TestProxyManager_RunningEndpoint(t *testing.T) {
	// Shared configuration
	config := config.AddDefaultGroupToConfig(config.Config{
		HealthCheckTimeout: 15,
		Models: map[string]config.ModelConfig{
			"model1": getTestSimpleResponderConfig("model1"),
			"model2": getTestSimpleResponderConfig("model2"),
		},
		LogLevel: "warn",
	})

	// Define a helper struct to parse the JSON response.
	type RunningResponse struct {
		Running []struct {
			Model string `json:"model"`
			State string `json:"state"`
		} `json:"running"`
	}

	// Create proxy once for all subtests; they run in order and share state.
	proxy := New(config)
	defer proxy.StopProcesses(StopWaitForInflightRequest)

	t.Run("no models loaded", func(t *testing.T) {
		req := httptest.NewRequest("GET", "/running", nil)
		w := CreateTestResponseRecorder()
		proxy.ServeHTTP(w, req)

		assert.Equal(t, http.StatusOK, w.Code)

		var response RunningResponse

		// Check if this is a valid JSON object.
		assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))

		// We should have an empty running array here.
		assert.Empty(t, response.Running, "expected no running models")
	})

	t.Run("single model loaded", func(t *testing.T) {
		// Load just a model.
		reqBody := `{"model":"model1"}`
		req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
		w := CreateTestResponseRecorder()
		proxy.ServeHTTP(w, req)
		assert.Equal(t, http.StatusOK, w.Code)

		// Simulate browser call for the `/running` endpoint.
		req = httptest.NewRequest("GET", "/running", nil)
		w = CreateTestResponseRecorder()
		proxy.ServeHTTP(w, req)
		assert.Equal(t, http.StatusOK, w.Code)

		var response RunningResponse
		if !assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) {
			return
		}

		// Expect exactly one entry; stop here otherwise so Running[0] cannot panic.
		if !assert.Len(t, response.Running, 1) {
			return
		}

		// Is this the right model?
		assert.Equal(t, "model1", response.Running[0].Model)

		// Is the model loaded?
		assert.Equal(t, "ready", response.Running[0].State)
	})
}

// TestProxyManager_AudioTranscriptionHandler sends a multipart
// /v1/audio/transcriptions request with a "model" form field and a file
// upload of random size. It checks that the upstream reports the expected
// model, file length, and Content-Length.
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
	config := config.AddDefaultGroupToConfig(config.Config{
		HealthCheckTimeout: 15,
		Models: map[string]config.ModelConfig{
			"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
		},
		LogLevel: "error",
	})

	proxy := New(config)
	defer proxy.StopProcesses(StopWaitForInflightRequest)

	// Create a buffer with multipart form data
	var b bytes.Buffer
	mw := multipart.NewWriter(&b)

	// Add the model field
	fw, err := mw.CreateFormField("model")
	assert.NoError(t, err)
	_, err = fw.Write([]byte("TheExpectedModel"))
	assert.NoError(t, err)

	// Add a file field
	fw, err = mw.CreateFormFile("file", "test.mp3")
	assert.NoError(t, err)
	// Generate random content length between 10 and 20
	contentLength := rand.Intn(11) + 10 // 10 to 20
	content := make([]byte, contentLength)
	_, err = fw.Write(content)
	assert.NoError(t, err)
	// Close writes the closing boundary; if it fails the form is incomplete.
	if !assert.NoError(t, mw.Close()) {
		return
	}

	// Create the request with the multipart form data
	req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
	req.Header.Set("Content-Type", mw.FormDataContentType())
	rec := CreateTestResponseRecorder()
	proxy.ServeHTTP(rec, req)

	// Verify the response
	assert.Equal(t, http.StatusOK, rec.Code)
	var response map[string]string
	err = json.Unmarshal(rec.Body.Bytes(), &response)
	assert.NoError(t, err)
	assert.Equal(t, "TheExpectedModel", response["model"])
	assert.Equal(t, fmt.Sprintf("The length of the file is %d bytes", contentLength), response["text"]) // matches simple-responder
	// NOTE(review): 370 is presumably the fixed multipart framing overhead
	// (boundaries and part headers) around the file bytes — confirm
	assert.Equal(t, strconv.Itoa(370+contentLength), response["h_content_length"])
}

// TestProxyManager_UseModelName verifies that a model's useModelName setting
// replaces the client-requested model name in what is sent upstream. Both
// JSON chat requests and multipart audio transcription requests are covered.
func TestProxyManager_UseModelName(t *testing.T) {
	upstreamModelName := "upstreamModel"
	modelConfig := getTestSimpleResponderConfig(upstreamModelName)
	modelConfig.UseModelName = upstreamModelName

	conf := config.AddDefaultGroupToConfig(config.Config{
		HealthCheckTimeout: 15,
		Models: map[string]config.ModelConfig{
			"model1": modelConfig,
		},
		LogLevel: "error",
	})

	proxy := New(conf)
	defer proxy.StopProcesses(StopWaitForInflightRequest)

	requestedModel := "model1"

	t.Run("useModelName over rides requested model: /v1/chat/completions", func(t *testing.T) {
		reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
		req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
		w := CreateTestResponseRecorder()

		proxy.ServeHTTP(w, req)
		assert.Equal(t, http.StatusOK, w.Code)
		assert.Contains(t, w.Body.String(), upstreamModelName)

		// make sure the content length was set correctly for the rewritten body
		// simple-responder will return the content length it got in the response
		body := w.Body.Bytes()
		contentLength := int(gjson.GetBytes(body, "h_content_length").Int())
		assert.Equal(t, len(fmt.Sprintf(`{"model":"%s"}`, upstreamModelName)), contentLength)
	})

	t.Run("useModelName over rides requested model: /v1/audio/transcriptions", func(t *testing.T) {
		// Create a buffer with multipart form data
		var b bytes.Buffer
		mw := multipart.NewWriter(&b)

		// Add the model field
		fw, err := mw.CreateFormField("model")
		assert.NoError(t, err)
		_, err = fw.Write([]byte(requestedModel))
		assert.NoError(t, err)

		// Add a file field
		fw, err = mw.CreateFormFile("file", "test.mp3")
		assert.NoError(t, err)
		_, err = fw.Write([]byte("test"))
		assert.NoError(t, err)
		// Close writes the closing boundary; if it fails the form is incomplete.
		if !assert.NoError(t, mw.Close()) {
			return
		}

		// Create the request with the multipart form data
		req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
		req.Header.Set("Content-Type", mw.FormDataContentType())
		rec := CreateTestResponseRecorder()
		proxy.ServeHTTP(rec, req)

		// Verify the response
		assert.Equal(t, http.StatusOK, rec.Code)
		var response map[string]string
		err = json.Unmarshal(rec.Body.Bytes(), &response)
		assert.NoError(t, err)
		assert.Equal(t, upstreamModelName, response["model"])
	})
}

// TestProxyManager_CORSOptionsHandler checks CORS handling on
// /v1/chat/completions.
//   - OPTIONS requests get 204 with the allow-* headers.
//   - Access-Control-Allow-Headers echoes Access-Control-Request-Headers when
//     it is sent.
//   - A plain GET is not routed and returns 404.
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
	cfg := config.AddDefaultGroupToConfig(config.Config{
		HealthCheckTimeout: 15,
		Models: map[string]config.ModelConfig{
			"model1": getTestSimpleResponderConfig("model1"),
		},
		LogLevel: "error",
	})

	const allowMethods = "GET, POST, PUT, PATCH, DELETE, OPTIONS"

	type corsCase struct {
		name            string
		method          string
		requestHeaders  map[string]string
		expectedStatus  int
		expectedHeaders map[string]string
	}

	cases := []corsCase{
		{
			name:           "OPTIONS with no headers",
			method:         "OPTIONS",
			expectedStatus: http.StatusNoContent,
			expectedHeaders: map[string]string{
				"Access-Control-Allow-Origin":  "*",
				"Access-Control-Allow-Methods": allowMethods,
				"Access-Control-Allow-Headers": "Content-Type, Authorization, Accept, X-Requested-With",
			},
		},
		{
			name:           "OPTIONS with specific headers",
			method:         "OPTIONS",
			expectedStatus: http.StatusNoContent,
			requestHeaders: map[string]string{
				"Access-Control-Request-Headers": "X-Custom-Header, Some-Other-Header",
			},
			expectedHeaders: map[string]string{
				"Access-Control-Allow-Origin":  "*",
				"Access-Control-Allow-Methods": allowMethods,
				"Access-Control-Allow-Headers": "X-Custom-Header, Some-Other-Header",
			},
		},
		{
			name:           "Non-OPTIONS request",
			method:         "GET",
			expectedStatus: http.StatusNotFound, // Since we don't have a GET route defined
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// a fresh proxy per case keeps the cases independent
			pm := New(cfg)
			defer pm.StopProcesses(StopWaitForInflightRequest)

			req := httptest.NewRequest(tc.method, "/v1/chat/completions", nil)
			for header, value := range tc.requestHeaders {
				req.Header.Set(header, value)
			}

			rec := CreateTestResponseRecorder()
			pm.ServeHTTP(rec, req)

			assert.Equal(t, tc.expectedStatus, rec.Code)
			for header, want := range tc.expectedHeaders {
				assert.Equal(t, want, rec.Header().Get(header))
			}
		})
	}
}

// TestProxyManager_Upstream verifies that /upstream/<model>/... requests reach
// the model's process whether the model is addressed by its configured name
// or by an alias.
func TestProxyManager_Upstream(t *testing.T) {
	configStr := fmt.Sprintf(`
logLevel: error
models:
  model1:
    cmd: %s -port ${PORT} -silent -respond model1
    aliases: [model-alias]
`, getSimpleResponderPath())

	config, err := config.LoadConfigFromReader(strings.NewReader(configStr))
	// continuing with a zero config would only produce confusing failures below
	if !assert.NoError(t, err) {
		return
	}

	proxy := New(config)
	defer proxy.StopProcesses(StopWaitForInflightRequest)
	t.Run("main model name", func(t *testing.T) {
		req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
		rec := CreateTestResponseRecorder()
		proxy.ServeHTTP(rec, req)
		assert.Equal(t, http.StatusOK, rec.Code)
		assert.Equal(t, "model1", rec.Body.String())
	})

	t.Run("model alias", func(t *testing.T) {
		req := httptest.NewRequest("GET", "/upstream/model-alias/test", nil)
		rec := CreateTestResponseRecorder()
		proxy.ServeHTTP(rec, req)
		assert.Equal(t, http.StatusOK, rec.Code)
		assert.Equal(t, "model1", rec.Body.String())
	})
}

// TestProxyManager_ChatContentLength verifies that a chat request reaches the
// upstream with a Content-Length equal to the number of bytes the client sent.
func TestProxyManager_ChatContentLength(t *testing.T) {
	config := config.AddDefaultGroupToConfig(config.Config{
		HealthCheckTimeout: 15,
		Models: map[string]config.ModelConfig{
			"model1": getTestSimpleResponderConfig("model1"),
		},
		LogLevel: "error",
	})

	proxy := New(config)
	defer proxy.StopProcesses(StopWaitForInflightRequest)

	reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
	req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
	w := CreateTestResponseRecorder()

	proxy.ServeHTTP(w, req)
	assert.Equal(t, http.StatusOK, w.Code)
	var response map[string]interface{}
	assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
	// simple-responder reports the Content-Length it received as h_content_length.
	// Derive the expectation from reqBody so it stays correct if the body changes.
	assert.Equal(t, strconv.Itoa(len(reqBody)), response["h_content_length"])
	assert.Equal(t, "model1", response["responseMessage"])
}

// TestProxyManager_FiltersStripParams verifies that filters.stripParams
// removes the listed keys from a chat request body before it is forwarded.
// "model" is also in the strip list, yet the forwarded body must still
// contain it.
func TestProxyManager_FiltersStripParams(t *testing.T) {
	modelConfig := getTestSimpleResponderConfig("model1")
	modelConfig.Filters = config.ModelFilters{
		StripParams: "temperature, model, stream",
	}

	config := config.AddDefaultGroupToConfig(config.Config{
		HealthCheckTimeout: 15,
		LogLevel:           "error",
		Models: map[string]config.ModelConfig{
			"model1": modelConfig,
		},
	})

	proxy := New(config)
	defer proxy.StopProcesses(StopWaitForInflightRequest)
	reqBody := `{"model":"model1", "temperature":0.1, "x_param":"123", "y_param":"abc", "stream":true}`
	req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
	w := CreateTestResponseRecorder()

	proxy.ServeHTTP(w, req)
	assert.Equal(t, http.StatusOK, w.Code)
	var response map[string]interface{}
	assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))

	// request_body is the body simple-responder received:
	// `temperature` and `stream` are gone but model remains
	assert.Equal(t, `{"model":"model1", "x_param":"123", "y_param":"abc"}`, response["request_body"])
}

func TestProxyManager_HealthEndpoint(t *testing.T) {
	config := config.AddDefaultGroupToConfig(config.Config{
		HealthCheckTimeout: 15,
		Models: map[string]config.ModelConfig{
			"model1": getTestSimpleResponderConfig("model1"),
		},
		LogLevel: "error",
	})

	proxy := New(config)
	defer proxy.StopProcesses(StopWaitForInflightRequest)
	req := httptest.NewRequest("GET", "/health", nil)
	rec := CreateTestResponseRecorder()
	proxy.ServeHTTP(rec, req)
	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, "OK", rec.Body.String())
}

// TestProxyManager_CompletionEndpoint ensures the llama-server specific
// /completion endpoint is proxied to the requested model.
func TestProxyManager_CompletionEndpoint(t *testing.T) {
	cfg := config.AddDefaultGroupToConfig(config.Config{
		HealthCheckTimeout: 15,
		Models: map[string]config.ModelConfig{
			"model1": getTestSimpleResponderConfig("model1"),
		},
		LogLevel: "error",
	})

	pm := New(cfg)
	defer pm.StopProcesses(StopWaitForInflightRequest)

	payload := bytes.NewBufferString(`{"model":"model1"}`)
	rec := CreateTestResponseRecorder()
	pm.ServeHTTP(rec, httptest.NewRequest("POST", "/completion", payload))

	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Contains(t, rec.Body.String(), "model1")
}

// TestProxyManager_StartupHooks verifies the hooks.on_startup.preload
// setting. Creating the proxy must start the listed models, emit one
// ModelPreloadedEvent per model, and leave both models in StateReady.
func TestProxyManager_StartupHooks(t *testing.T) {
	// using real YAML as the configuration has gotten more complex
	// is the right approach as LoadConfigFromReader() does a lot more
	// than parse YAML now. Eventually migrate all tests to use this approach
	configStr := strings.ReplaceAll(`
logLevel: error
hooks:
  on_startup:
    preload:
      - model1
      - model2
groups:
  preloadTestGroup:
    swap: false
    members:
       - model1
       - model2
models:
  model1:
    cmd: ${simpleresponderpath} --port ${PORT} --silent --respond model1
  model2:
      cmd: ${simpleresponderpath} --port ${PORT} --silent --respond model2
`, "${simpleresponderpath}", simpleResponderPath)

	// Load the configuration through the real config loader
	config, err := config.LoadConfigFromReader(strings.NewReader(configStr))
	if !assert.NoError(t, err, "Invalid configuration") {
		return
	}

	preloadChan := make(chan ModelPreloadedEvent, 2) // buffer for 2 expected events

	// subscribe before New so no preload event can be missed
	unsub := event.On(func(e ModelPreloadedEvent) {
		preloadChan <- e
	})

	defer unsub()

	// Create the proxy which should trigger preloading
	proxy := New(config)
	defer proxy.StopProcesses(StopWaitForInflightRequest)

	for i := 0; i < 2; i++ {
		select {
		case <-preloadChan:
		case <-time.After(5 * time.Second):
			t.Fatal("timed out waiting for models to preload")
		}
	}
	// make sure they are both loaded
	_, foundGroup := proxy.processGroups["preloadTestGroup"]
	if !assert.True(t, foundGroup, "preloadTestGroup should exist") {
		return
	}
	assert.Equal(t, StateReady, proxy.processGroups["preloadTestGroup"].processes["model1"].CurrentState())
	assert.Equal(t, StateReady, proxy.processGroups["preloadTestGroup"].processes["model2"].CurrentState())
}

// TestProxyManager_StreamingEndpointsReturnNoBufferingHeader checks that the
// proxy's own streaming endpoints respond 200 with "X-Accel-Buffering: no",
// the header reverse proxies such as nginx use to disable response buffering.
func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
	cfg := config.AddDefaultGroupToConfig(config.Config{
		HealthCheckTimeout: 15,
		Models: map[string]config.ModelConfig{
			"model1": getTestSimpleResponderConfig("model1"),
		},
		LogLevel: "error",
	})

	proxy := New(cfg)
	defer proxy.StopProcesses(StopWaitForInflightRequest)

	endpoints := []string{
		"/api/events",
		"/logs/stream",
		"/logs/stream/proxy",
		"/logs/stream/upstream",
	}

	for _, endpoint := range endpoints {
		t.Run(endpoint, func(t *testing.T) {
			// The short deadline is what ends these streaming handlers.
			ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
			defer cancel()

			req := httptest.NewRequest("GET", endpoint, nil).WithContext(ctx)
			rec := CreateTestResponseRecorder()

			done := make(chan struct{})
			go func() {
				defer close(done)
				proxy.ServeHTTP(rec, req)
			}()

			// Wait for the handler goroutine to exit; only then is it safe to
			// read rec without racing the handler's writes. The outer deadline
			// turns a handler that ignores cancellation into a test failure
			// rather than a hung test run.
			select {
			case <-done:
			case <-time.After(5 * time.Second):
				t.Fatal("handler did not return after the request context was cancelled")
			}

			assert.Equal(t, http.StatusOK, rec.Code)
			assert.Equal(t, "no", rec.Header().Get("X-Accel-Buffering"))
		})
	}
}

// TestProxyManager_ProxiedStreamingEndpointReturnsNoBufferingHeader sends a
// streaming chat-completion request through the proxy to an upstream model and
// expects an event-stream response carrying "X-Accel-Buffering: no".
func TestProxyManager_ProxiedStreamingEndpointReturnsNoBufferingHeader(t *testing.T) {
	const modelName = "streaming-model"

	cfg := config.AddDefaultGroupToConfig(config.Config{
		HealthCheckTimeout: 15,
		Models: map[string]config.ModelConfig{
			modelName: getTestSimpleResponderConfig(modelName),
		},
		LogLevel: "error",
	})

	pm := New(cfg)
	defer pm.StopProcesses(StopWaitForInflightRequest)

	// stream=true in the query string makes simple-responder answer with
	// text/event-stream.
	payload := bytes.NewBufferString(`{"model":"streaming-model"}`)
	recorder := CreateTestResponseRecorder()
	pm.ServeHTTP(recorder, httptest.NewRequest("POST", "/v1/chat/completions?stream=true", payload))

	headers := recorder.Header()
	assert.Equal(t, http.StatusOK, recorder.Code)
	assert.Equal(t, "no", headers.Get("X-Accel-Buffering"))
	assert.Contains(t, headers.Get("Content-Type"), "text/event-stream")
}

// TestProxyManager_ApiGetVersion checks that GET /api/version responds with a
// JSON object containing the build_date, commit and version values previously
// passed to SetVersion.
func TestProxyManager_ApiGetVersion(t *testing.T) {
	cfg := config.AddDefaultGroupToConfig(config.Config{
		HealthCheckTimeout: 15,
		Models: map[string]config.ModelConfig{
			"model1": getTestSimpleResponderConfig("model1"),
		},
		LogLevel: "error",
	})

	// Expected response fields, keyed by JSON attribute name.
	versionTest := map[string]string{
		"build_date": "1970-01-01T00:00:00Z",
		"commit":     "cc915ddb6f04a42d9cd1f524e1d46ec6ed069fdc",
		"version":    "v001",
	}

	proxy := New(cfg)
	defer proxy.StopProcesses(StopWaitForInflightRequest)
	proxy.SetVersion(versionTest["build_date"], versionTest["commit"], versionTest["version"])

	req := httptest.NewRequest("GET", "/api/version", nil)
	w := CreateTestResponseRecorder()

	proxy.ServeHTTP(w, req)

	assert.Equal(t, http.StatusOK, w.Code)

	// Ensure json response
	assert.Equal(t, "application/json; charset=utf-8", w.Header().Get("Content-Type"))

	// Stop here if the body is not a JSON object of strings; otherwise every
	// key comparison below would fail with a misleading empty-string mismatch.
	response := map[string]string{}
	if !assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) {
		return
	}
	for key, value := range versionTest {
		assert.Equal(t, value, response[key], "%s value %s should match response %s", key, value, response[key])
	}
}
